"""
Maps of annual dark ice duration from MODIS @ ~600 m for SW GrIS (on MAR grid system)

@author Andrew Tedstone (a.j.tedstone@bristol.ac.uk) September 2016
"""

import numpy as np
import xarray as xr
import pandas as pd

import georaster


""" PARAMETERS """

# If True, the "Compute basic stats" section below is run, computing some
# basic metrics across the whole time series. Some of these are already
# computed year-by-year during the onset timing analysis and exported to
# netCDF, so this is off by default.
compute_basic = False

# Reflectance thresholds for defining bare and dark ice, as per Shimada et al.
# A pixel is classed as bare ice when band 2 reflectance is below
# thresh_bare_ice, and as dark ice when band 1 reflectance is below
# thresh_dark_ice.
thresh_bare_ice = 0.6
thresh_dark_ice = 0.45

# Lengths (in time steps, i.e. days) of the rolling windows used for finding
# the bare ice and dark ice onset dates.
window_bare = 7
window_dark = 7

# Minimum number of positive (bare / dark) observations needed within the
# corresponding window for that window to count as an onset window.
# - If set equal to the window length, that many consecutive positive days
#   are required.
# - If set to less than the window length, the positive days may be
#   non-consecutive, but the window is only accepted if
#   n_positive + n_bad_quality_obs == window length, i.e. every remaining day
#   must be a bad quality observation; no snow (for bare) or non-dark (for
#   dark) values are allowed to occur in the window.
# TODO: should a tolerance on the thresholds (e.g. + 0.02) be allowed here?
n_bare = 3
n_dark = 3

""" END PARAMETERS """



"""
Load mask
"""
# Ice mask raster, generated by GDAL from the 90 m ice sheet mask
mask = georaster.SingleBandRaster('/scratch/MOD09GA.006.SW/GIMP_IceMask_MAR613mSW.tif')

# Positive cells are ice sheet (set to 1); everything else becomes nan
mask.r = np.where(mask.r > 0, 1, np.nan)

# Hand-picked (row window, column window) index ranges which are also
# blanked out of the mask
excluded_windows = (
	(slice(0, 250), slice(0, 300)),
	(slice(575, 945), slice(0, 157)),
	(slice(1169, 1417), slice(0, 150)),
	(slice(823, 893), slice(150, 200)),
	(slice(830, 860), slice(197, 214)),
)
for row_window, col_window in excluded_windows:
	mask.r[row_window, col_window] = np.nan


"""
Compute basic stats, some of these are also computed year-by-year below
"""
if compute_basic:
	# Band 1 surface reflectance for the whole time series
	modis_ds = xr.open_mfdataset('/scratch/MOD09GA.006.SW/*2017*b1234q.nc',
		chunks={'TIME':365})
	b01 = modis_ds.sur_refl_b01_1

	# MOD10A1 albedo. This section previously used `modis_al` without ever
	# defining it; it is loaded here from the same product and variable that
	# the year-by-year loop below uses.
	# NOTE(review): confirm this is the intended albedo source for these stats.
	modis_al_ds = xr.open_dataset('/scratch/MOD10A1.006.SW/MOD10A1.2017.006.reproj.nc',
		chunks={'TIME':365})
	modis_al = modis_al_ds.Snow_Albedo_Daily_Tile

	# Enumerate MODIS dark extent: number of days per year on which band 1
	# reflectance is below thresh_dark_ice.
	# NOTE(review): an earlier comment here referred to a 0.4 / 0.5 threshold
	# (see progress_slides.odp); the threshold actually applied is
	# thresh_dark_ice.
	modis_dark = b01.where((b01 < thresh_dark_ice)) \
		.notnull().resample(dim='TIME',freq='AS',how='sum')

	# Plot number of dark days
	#modis_dark.plot(col='TIME', vmin=0, vmax=100)

	# Dark days as % of total clear-sky observation days. Albedo values in
	# the range 1-100 are counted as good observations.
	modis_good = modis_al.where((modis_al < 101) & (modis_al > 0)) \
		.notnull().resample(dim='TIME',freq='AS',how='sum')
	modis_dark_prop = (100 / modis_good) * modis_dark
	#figure(),modis_dark_prop.plot(col='TIME',vmin=0,vmax=100)

	# Average albedo of dark ice each year (albedo between 0 and 40)
	# Could also threshold these by n.obs. per pixel to clean up the noise
	modis_dark_avg = modis_al.where((modis_al < 40) & (modis_al > 0)) \
		.resample(dim='TIME',freq='AS',how='mean',skipna=True)
	# And associated s.d.
	modis_dark_std = modis_al.where((modis_al < 40) & (modis_al > 0)) \
		.resample(dim='TIME',freq='AS',how='std')





# Temporary storage: one entry per processed year, concatenated along TIME
# once the loop has finished
all_bare = []
all_dark = []
all_alt = []
all_dark_duration = []
all_dark_first = []
all_bad_count = []
all_postsnow_bad_count = []
all_dark_win = []
all_bare_win = []

for year in range(2017, 2018):
	print(year)

	# Load surface reflectance. 
	ds_refl_yr = xr.open_mfdataset('/scratch/MOD09GA.006.SW/MOD09GA.' + str(year) + '.006.reprojb1234*.nc',
		chunks={'TIME':5})

	# Load in albedo in order to use cloud mask, need to slice as files begin in April
	ds_al_yr = xr.open_dataset('/scratch/MOD10A1.006.SW/MOD10A1.' + str(year) + '.006.reproj.nc',
		chunks={'TIME':5})
	time_slice = slice(str(year)+'-06-01', str(year)+'-08-30')
	al_yr = ds_al_yr.Snow_Albedo_Daily_Tile.sel(TIME=time_slice)
	al_yr.load()

	# Albedo values above 100 are treated as cloud / bad quality, values
	# below 101 as usable observations. Computed once and reused below.
	bad_quality = al_yr > 100
	clear_sky = al_yr < 101

	# Map bare ice, slice by time to be sure
	b02_yr = ds_refl_yr.sur_refl_b02_1.sel(TIME=time_slice)
	# Day of year of every time step, as an array along TIME
	doy_b02 = b02_yr.groupby(b02_yr['TIME.dayofyear']).apply(lambda x: x).dayofyear
	# Load dask array into memory
	b02_yr.load()
	# Day of year wherever the pixel is bare and observed, nan elsewhere
	bare_doys = doy_b02.where((b02_yr < thresh_bare_ice) & clear_sky)
	# Number of bare days within each window. Windows are right-aligned.
	consec_bare_doys = bare_doys.notnull().rolling(TIME=window_bare).sum()
	# Get count of bad quality days within each window (e.g. clouds)
	bare_qual = doy_b02.where(bad_quality).notnull().rolling(TIME=window_bare).sum()
	# First window of the year with at least n_bare bare days and in which
	# every other day is bad quality (no snow observations allowed)
	bare_onset_win = doy_b02.where((consec_bare_doys >= n_bare) & \
		(consec_bare_doys + bare_qual == window_bare)) \
		.resample(dim='TIME', freq='AS', how='first')
	all_bare_win.append(bare_onset_win)

	# bare_onset_win is only set to the last day of the first window positively 
	# identified as bare
	# Here, find the first day of bare onset within the chosen window
	b = bare_doys.where((bare_doys['dayofyear'] <= bare_onset_win.isel(TIME=0)) & \
		(bare_doys['dayofyear'] > (bare_onset_win.isel(TIME=0)-window_bare)) & \
		(bare_doys.notnull()))
	bare_onset = doy_b02.where(b.notnull()).resample(dim='TIME', freq='AS', how='first')
	all_bare.append(bare_onset)

	# Release the band 2 data
	b02_yr = None

	# Dark ice
	b01_yr = ds_refl_yr.sur_refl_b01_1.sel(TIME=time_slice)
	b01_yr.load()
	doy_b01 = b01_yr.groupby(b01_yr['TIME.dayofyear']).apply(lambda x: x).dayofyear
	# Pixel is dark and observed; reused for onset, first appearance and duration
	is_dark = (b01_yr < thresh_dark_ice) & clear_sky
	dark_doys = doy_b01.where(is_dark)
	consec_dark_doys = dark_doys.notnull().rolling(TIME=window_dark).sum()
	# Get count of bad quality days within each window (e.g. clouds)
	dark_qual = doy_b01.where(bad_quality).notnull().rolling(TIME=window_dark).sum()
	# First logic ensures min number of dark doys in window condition met
	# Second logic checks that no 'bare/snow' doys in there - only bad quality doys
	dark_onset_win = doy_b01.where((consec_dark_doys >= n_dark) & \
		(consec_dark_doys + dark_qual == window_dark)) \
		.resample(dim='TIME', freq='AS', how='first')
	all_dark_win.append(dark_onset_win)

	# dark_onset_win is only set to the last day of the first window positively 
	# identified as dark
	# Here, find the first day of dark onset within the chosen window
	d = dark_doys.where((dark_doys['dayofyear'] <= dark_onset_win.isel(TIME=0)) & \
		(dark_doys['dayofyear'] > (dark_onset_win.isel(TIME=0)-window_dark)) & \
		(dark_doys.notnull()))
	dark_onset = doy_b01.where(d.notnull()).resample(dim='TIME', freq='AS', how='first')
	all_dark.append(dark_onset)

	# Dark ice first annual appearance (no window requirement)
	dark_onset_first = doy_b01.where(is_dark) \
		.resample(dim='TIME', freq='AS', how='first')
	all_dark_first.append(dark_onset_first)

	# Dark counts: number of days per year classified as dark
	dark_duration = b01_yr.where(is_dark) \
		.notnull().resample(dim='TIME',freq='AS',how='sum')
	all_dark_duration.append(dark_duration)

	# Cloud/Bad quality counts
	no_obs = al_yr.where(bad_quality).notnull().resample(dim='TIME',
		freq='AS', how='sum')
	all_bad_count.append(no_obs)

	# Bad quality after snow clear date
	# Counting starts window_bare days before the bare ice onset day. A new
	# array is built for this: bare_onset has already been appended to
	# all_bare, so modifying it in place would shift the exported bare ice
	# onset dates by window_bare days.
	postsnow_start = bare_onset.isel(TIME=0) - window_bare
	# Compare against the actual day of year of each time step. The series
	# starts on 1 June, so a 1-based position within the series is not a day
	# of year and would never reach the onset value. Pixels without a bare
	# onset (nan) never satisfy the comparison and so get a count of zero.
	is_postsnow = al_yr['TIME.dayofyear'] >= postsnow_start
	# Only bad quality observations are counted
	postsnow_bad_count = al_yr.where(bad_quality & is_postsnow).count(dim='TIME')
	# pd.Timestamp replaces the pd.datetime alias, which newer pandas no
	# longer provides
	postsnow_bad_count['TIME'] = pd.Timestamp(year, 1, 1)
	all_postsnow_bad_count.append(postsnow_bad_count)

	# Close file links
	b01_yr = None
	al_yr = None
	ds_refl_yr.close()
	ds_al_yr.close()

def _label(data_array, description):
	"""
	Attach a description to data_array and return it.

	The description is set as the array's name (as before) and also stored
	in attrs['long_name']. The name alone is replaced by the dictionary key
	when the array is placed into an xr.Dataset, whereas attributes are kept
	and written out with the variable.
	"""
	data_array.name = description
	data_array.attrs['long_name'] = description
	return data_array


# Create new xarrays of the temporal data by joining the per-year results
# along TIME
bare = _label(xr.concat(all_bare, dim='TIME'),
	'Day of year of bare ice onset, selected using window strategy')
dark = _label(xr.concat(all_dark, dim='TIME'),
	'Day of year of dark ice onset, selected using window strategy')
#alt = xr.concat(all_alt, dim='TIME')
dark_dur = _label(xr.concat(all_dark_duration, dim='TIME'),
	'Number of days in which pixel classified as dark, window strategy')
# The mask raster is flipped vertically before being placed on the Y/X
# coordinates of the results
mask_nc = _label(
	xr.DataArray(np.flipud(mask.r),coords=[bare.Y,bare.X], dims=['Y','X']),
	'Ice mask generated from GIMP 90m mask and resampled to MAR projection. Peripheral ice areas removed during netCDF preparation.')
dark_first = _label(xr.concat(all_dark_first, dim='TIME'),
	'Day of year of dark ice onset, very first day classified as dark')
bad_dur = _label(xr.concat(all_bad_count, dim='TIME'),
	'Number of days in which pixel classified as cloudy/bad quality according to MOD10A1')
postsnow_bad_dur = _label(xr.concat(all_postsnow_bad_count, dim='TIME'),
	'Number of days in which pixel cloudy/bad quality according to MOD10A1 after the snow has retreated')
dark_win = _label(xr.concat(all_dark_win, dim='TIME'),
	'The last day in the first positive dark ice onset window.')
bare_win = _label(xr.concat(all_bare_win, dim='TIME'),
	'The last day in the first positive bare ice onset window.')

## Export to netCDF
# Variables to write, keyed by their name in the output file
output_vars = {
	'bare':bare,
	'dark':dark,
	'dark_dur':dark_dur,
	'mask':mask_nc,
	'dark_first':dark_first,
	'bad_dur':bad_dur,
	'postsnow_bad_dur':postsnow_bad_dur,
	'dark_win':dark_win,
	'bare_win':bare_win,
}
dataset = xr.Dataset(output_vars)

# Global attributes: provenance plus the parameter values used for this run
for attr_name, attr_value in (
	('time_period', '1 June to 30 August each year'),
	('history', 'Generated by Andrew Tedstone using bitbucket/albedo_modis_mar/dark_ice_mapping_SW.py from MOD09GA V006 ground reflectance data and MOD10A1 Snow Albedo data'),
	('thresh_bare_ice', thresh_bare_ice),
	('thresh_dark_ice', thresh_dark_ice),
	('window_bare', window_bare),
	('window_dark', window_dark),
	('n_bare', n_bare),
	('n_dark', n_dark),
):
	dataset.attrs[attr_name] = attr_value

dataset.to_netcdf(path='/scratch/physical_controls/MOD09GA.006.onset.2017.bare60.dark45.JJA.win7.b3.d3.nc')